{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PyMSMvwSAR7B"
      },
      "source": [
        "# MathWriting Dataset\n",
        "\n",
        "This notebook contains code snippets showing how to read data from the MathWriting dataset, and do some minimal processing.\n\nLicensed under the Apache License, Version 2.0\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import dataclasses\n",
        "import json\n",
        "import os\n",
        "import pprint\n",
        "import re\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as pl\n",
        "import matplotlib.patches as mpl_patches\n",
        "\n",
        "from xml.etree import ElementTree"
      ],
      "metadata": {
        "id": "_m8W0Gj0eAAP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Download\n",
        "\n",
        "Download the data and unpack it.\n"
      ],
      "metadata": {
        "id": "ZBB1VjlLzcnP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!wget -nc https://storage.googleapis.com/mathwriting_data/mathwriting-2024-excerpt.tgz\n",
        "!tar zxf mathwriting-2024-excerpt.tgz\n",
        "!ls mathwriting-2024-excerpt\n",
        "\n",
        "MATHWRITING_ROOT_DIR='mathwriting-2024-excerpt'"
      ],
      "metadata": {
        "id": "_buCRTuQdbkF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# To use the full dataset, uncomment and run this cell instead of the one above.\n",
        "#!wget -nc https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz\n",
        "#!tar zxf mathwriting-2024.tgz\n",
        "#!ls mathwriting-2024\n",
        "#MATHWRITING_ROOT_DIR='mathwriting-2024'"
      ],
      "metadata": {
        "id": "dDDC5P3Hdl0V"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Reading InkML files\n",
        "\n",
        "All inks are provided in InkML format. The code below shows how to read the specific structure used by MathWriting."
      ],
      "metadata": {
        "id": "hZDNaBh2dupu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "@dataclasses.dataclass\n",
        "class Ink:\n",
        "  \"\"\"Represents a single ink, as read from an InkML file.\"\"\"\n",
        "  # Every stroke in the ink.\n",
        "  # Each stroke array has shape (3, number of points), where the first\n",
        "  # dimensions are (x, y, timestamp), in that order.\n",
        "  strokes: list[np.ndarray]\n",
        "  # Metadata present in the InkML.\n",
        "  annotations: dict[str, str]\n",
        "\n",
        "\n",
        "def read_inkml_file(filename: str) -> Ink:\n",
        "  \"\"\"Simple reader for MathWriting's InkML files.\"\"\"\n",
        "  with open(filename, \"r\") as f:\n",
        "    root = ElementTree.fromstring(f.read())\n",
        "\n",
        "  strokes = []\n",
        "  annotations = {}\n",
        "\n",
        "  for element in root:\n",
        "    tag_name = element.tag.removeprefix('{http://www.w3.org/2003/InkML}')\n",
        "    if tag_name == 'annotation':\n",
        "      annotations[element.attrib.get('type')] = element.text\n",
        "\n",
        "    elif tag_name == 'trace':\n",
        "      points = element.text.split(',')\n",
        "      stroke_x, stroke_y, stroke_t = [], [], []\n",
        "      for point in points:\n",
        "        x, y, t = point.split(' ')\n",
        "        stroke_x.append(float(x))\n",
        "        stroke_y.append(float(y))\n",
        "        stroke_t.append(float(t))\n",
        "      strokes.append(np.array((stroke_x, stroke_y, stroke_t)))\n",
        "\n",
        "  return Ink(strokes=strokes, annotations=annotations)\n",
        "\n",
        "\n",
        "def display_ink(\n",
        "    ink: Ink,\n",
        "    *,\n",
        "    figsize: tuple[int, int]=(15, 10),\n",
        "    linewidth: int=2,\n",
        "    color=None):\n",
        "  \"\"\"Simple display for a single ink.\"\"\"\n",
        "  pl.figure(figsize=figsize)\n",
        "  for stroke in ink.strokes:\n",
        "    pl.plot(stroke[0], stroke[1], linewidth=linewidth, color=color)\n",
        "    pl.title(\n",
        "        f\"{ink.annotations.get('sampleId', '')} -- \"\n",
        "        f\"{ink.annotations.get('splitTagOriginal', '')} -- \"\n",
        "        f\"{ink.annotations.get('normalizedLabel', ink.annotations['label'])}\"\n",
        "    )\n",
        "  pl.gca().invert_yaxis()\n",
        "  pl.gca().axis('equal')"
      ],
      "metadata": {
        "id": "z9QXBDONdyG1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "ink = read_inkml_file(os.path.join(MATHWRITING_ROOT_DIR, 'train', '060ab59eeba54d2f.inkml'))\n",
        "pprint.pprint(ink.annotations)\n",
        "display_ink(ink, color='black')"
      ],
      "metadata": {
        "id": "nxqCqCAaf2WG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Reading Bounding Box Files\n",
        "\n",
        "These files contain bounding boxes with the symbol that should be displayed in it, plus the LaTeX string for the overall expression. The format used is JSONL: each line in the file is a separate JSON object containing all the bounding boxes for a single mathematical expression.\n",
        "\n",
        "The code below shows how to turn each of these lines into a Python class."
      ],
      "metadata": {
        "id": "BejIHtwQj3yi"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "@dataclasses.dataclass\n",
        "class BoundingBox:\n",
        "  \"\"\"A bounding box for a single symbol.\"\"\"\n",
        "  # The symbol that should appear in the bounding box, in LaTeX notation.\n",
        "  token: str\n",
        "  # Definition of the bounding box.\n",
        "  x_min: float\n",
        "  y_min: float\n",
        "  x_max: float\n",
        "  y_max: float\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class BoundingBoxes:\n",
        "  \"\"\"A set of bounding boxes for a complete mathematical expression.\"\"\"\n",
        "  # The complete (raw) LaTeX string.\n",
        "  label: str\n",
        "  # A normalized version of `label`.\n",
        "  normalized_label: str|None\n",
        "  # Bounding boxes for all symbols in the expression, in no particular order.\n",
        "  bboxes: list[BoundingBox]\n",
        "\n",
        "\n",
        "def read_bbox_file(filename: str, *, index: int=0) -> BoundingBoxes:\n",
        "  \"\"\"Reads a single bounding box from the input file.\"\"\"\n",
        "  with open(filename, \"r\") as f:\n",
        "    for n, line in enumerate(f):\n",
        "      if n < index:\n",
        "        continue\n",
        "\n",
        "      bboxes = json.loads(line)\n",
        "\n",
        "      symbol_bboxes = []\n",
        "      for bbox in bboxes['bboxes']:\n",
        "        symbol_bboxes.append(BoundingBox(\n",
        "            token=bbox['token'],\n",
        "            x_min=bbox['xMin'],\n",
        "            y_min=bbox['yMin'],\n",
        "            x_max=bbox['xMax'],\n",
        "            y_max=bbox['yMax']))\n",
        "\n",
        "      return BoundingBoxes(\n",
        "              label=bboxes['label'],\n",
        "              normalized_label=bboxes.get('normalizedLabel', None),\n",
        "              bboxes = symbol_bboxes)\n",
        "\n",
        "\n",
        "def display_bboxes(bboxes: BoundingBoxes, *, figsize: tuple[int, int]=(15, 10)):\n",
        "  \"\"\"Displays a set of bounding boxes for debugging purposes.\"\"\"\n",
        "  fig = pl.figure(figsize=figsize)\n",
        "  ax = fig.gca()\n",
        "  x_min = float('inf')\n",
        "  y_min = float('inf')\n",
        "  x_max = -float('inf')\n",
        "  y_max = -float('inf')\n",
        "  for bbox in bboxes.bboxes:\n",
        "    x_min = min(x_min, bbox.x_min)\n",
        "    y_min = min(y_min, bbox.y_min)\n",
        "    x_max = max(x_max, bbox.x_max)\n",
        "    y_max = max(y_max, bbox.y_max)\n",
        "\n",
        "    ax.add_patch(\n",
        "        mpl_patches.Polygon(\n",
        "            ((bbox.x_min,bbox.y_min),\n",
        "            (bbox.x_min,bbox.y_max),\n",
        "            (bbox.x_max,bbox.y_max),\n",
        "            (bbox.x_max, bbox.y_min)),\n",
        "            closed=True,\n",
        "            facecolor='none',\n",
        "            edgecolor='darkblue',\n",
        "            linewidth=2))\n",
        "\n",
        "  width = x_max - x_min\n",
        "  height = y_max - y_min\n",
        "\n",
        "  # We put the text in one go at the end to be able to scale it properly.\n",
        "  for bbox in bboxes.bboxes:\n",
        "    box_width = bbox.x_max - bbox.x_min\n",
        "    box_height = bbox.y_max - bbox.y_min\n",
        "    if bbox.token != r'\\frac':\n",
        "      ax.text(bbox.x_min+box_width/2,\n",
        "              bbox.y_min+box_height/2,\n",
        "              bbox.token,\n",
        "              verticalalignment='center',\n",
        "              horizontalalignment='center',\n",
        "              fontsize=100000 / max(width, height))\n",
        "\n",
        "  margin_ratio = 0.1\n",
        "  ax.set_xlim(x_min-margin_ratio*width, x_max+margin_ratio*width)\n",
        "  ax.set_ylim(y_max+margin_ratio*height, y_min-margin_ratio*height)\n",
        "  pl.title(bboxes.normalized_label or bboxes.label)\n"
      ],
      "metadata": {
        "id": "kB2s1KyAj76D"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "bboxes = read_bbox_file(\n",
        "    os.path.join(MATHWRITING_ROOT_DIR, 'synthetic-bboxes.jsonl'),\n",
        "    index=2)\n",
        "display_bboxes(bboxes)"
      ],
      "metadata": {
        "id": "zfrsWymUl3sP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Extracting Individual Symbols\n",
        "\n",
        "For the purpose of synthesizing inks from bounding boxes, isolated symbols are provided as inks. These were extracted from inks from the training set, and MathWriting contains all the information necessary to recompute this extraction.\n",
        "The code below shows how to do it.\n"
      ],
      "metadata": {
        "id": "7yLOlewTuq29"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "@dataclasses.dataclass\n",
        "class InkPart:\n",
        "  \"\"\"A pointer to a part of an ink corresponding to a single symbol.\"\"\"\n",
        "  # Which sample the symbol is from. Ex: '00016221aae38d32'\n",
        "  source_sample_id: str\n",
        "  # Which symbol it is. Ex: '\\sum'\n",
        "  label: str\n",
        "  # Indices of strokes in the source ink that cover the symbol.\n",
        "  stroke_indices: list[int]\n",
        "\n",
        "\n",
        "def read_symbols_file(filename: str) -> list[InkPart]:\n",
        "  \"\"\"Reads all the symbol definitions.\"\"\"\n",
        "  symbols = []\n",
        "  with open(filename, \"r\") as f:\n",
        "    for line in f:\n",
        "      symbol_json = json.loads(line)\n",
        "      symbols.append(InkPart(\n",
        "          source_sample_id=symbol_json['sourceSampleId'],\n",
        "          label=symbol_json['label'],\n",
        "          stroke_indices=symbol_json['strokeIndices']))\n",
        "  return symbols\n",
        "\n",
        "\n",
        "def get_symbol_ink(symbol: InkPart) -> Ink:\n",
        "  \"\"\"Computes the actual ink from an InkPart object.\"\"\"\n",
        "  ink = read_inkml_file(\n",
        "      os.path.join(MATHWRITING_ROOT_DIR, \"train\", f\"{symbol.source_sample_id}.inkml\"))\n",
        "  strokes = [ink.strokes[i] for i in symbol.stroke_indices]\n",
        "  return Ink(\n",
        "      strokes=strokes,\n",
        "      annotations={\n",
        "          'label': symbol.label,\n",
        "          'splitTagOriginal': ink.annotations['splitTagOriginal']})"
      ],
      "metadata": {
        "id": "bxDrnBr5CadX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Extract one isolated symbol and display it.\n",
        "symbols = read_symbols_file(os.path.join(MATHWRITING_ROOT_DIR, 'symbols.jsonl'))\n",
        "symbol_ink = get_symbol_ink(symbols[3902])\n",
        "display_ink(symbol_ink, color='black')"
      ],
      "metadata": {
        "id": "dZ1RtGcfDH7m"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Ink Rasterization\n",
        "\n",
        "This shows how to use Cairo for high-quality rendering of inks."
      ],
      "metadata": {
        "id": "0SkKo1kjfIHZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!apt-get install libcairo2-dev libjpeg-dev libgif-dev\n",
        "!pip install pycairo"
      ],
      "metadata": {
        "id": "V-VWq495wYAu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import cairo\n",
        "import math\n",
        "import PIL\n",
        "import PIL.Image\n",
        "\n",
        "\n",
        "def cairo_to_pil(surface: cairo.ImageSurface) -> PIL.Image.Image:\n",
        "  \"\"\"Converts a ARGB Cairo surface into an RGB PIL image.\"\"\"\n",
        "  size = (surface.get_width(), surface.get_height())\n",
        "  stride = surface.get_stride()\n",
        "  with surface.get_data() as memory:\n",
        "    return PIL.Image.frombuffer(\n",
        "        'RGB', size, memory.tobytes(), 'raw', 'BGRX', stride\n",
        "    )\n",
        "\n",
        "\n",
        "def render_ink(\n",
        "    ink: Ink,\n",
        "    *,\n",
        "    margin: int = 10,\n",
        "    stroke_width: float = 1.5,\n",
        "    stroke_color: tuple[float, float, float] = (0.0, 0.0, 0.0),\n",
        "    background_color: tuple[float, float, float] = (1.0, 1.0, 1.0),\n",
        ") -> PIL.Image.Image:\n",
        "  \"\"\"Renders an ink as a PIL image using Cairo.\n",
        "\n",
        "  The image size is chosen to fit the entire ink while having one pixel per\n",
        "  InkML unit.\n",
        "\n",
        "  Args:\n",
        "    margin: size of the blank margin around the image (pixels)\n",
        "    stroke_width: width of each stroke (pixels)\n",
        "    stroke_color: color to paint the strokes with\n",
        "    background_color: color to fill the background with\n",
        "\n",
        "  Returns:\n",
        "    Rendered ink, as a PIL image.\n",
        "  \"\"\"\n",
        "\n",
        "  # Compute transformation to fit the ink in the image.\n",
        "  xmin, ymin = np.vstack([stroke[:2].min(axis=1) for stroke in ink.strokes]).min(axis=0)\n",
        "  xmax, ymax = np.vstack([stroke[:2].max(axis=1) for stroke in ink.strokes]).max(axis=0)\n",
        "  width = int(xmax - xmin + 2*margin)\n",
        "  height = int(ymax - ymin + 2*margin)\n",
        "\n",
        "  shift_x = - xmin + margin\n",
        "  shift_y = - ymin + margin\n",
        "\n",
        "  def apply_transform(ink_x: float, ink_y: float):\n",
        "    return ink_x + shift_x, ink_y + shift_y\n",
        "\n",
        "  # Create the canvas with the background color\n",
        "  surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, width, height)\n",
        "  ctx = cairo.Context(surface)\n",
        "  ctx.set_source_rgb(*background_color)\n",
        "  ctx.paint()\n",
        "\n",
        "  # Set pen parameters\n",
        "  ctx.set_source_rgb(*stroke_color)\n",
        "  ctx.set_line_width(stroke_width)\n",
        "  ctx.set_line_cap(cairo.LineCap.ROUND)\n",
        "  ctx.set_line_join(cairo.LineJoin.ROUND)\n",
        "\n",
        "  for stroke in ink.strokes:\n",
        "    if len(stroke[0]) == 1:\n",
        "      # For isolated points we just draw a filled disk with a diameter equal\n",
        "      # to the line width.\n",
        "      x, y = apply_transform(stroke[0, 0], stroke[1, 0])\n",
        "      ctx.arc(x, y, stroke_width / 2, 0, 2 * math.pi)\n",
        "      ctx.fill()\n",
        "\n",
        "    else:\n",
        "      ctx.move_to(*apply_transform(stroke[0,0], stroke[1,0]))\n",
        "\n",
        "      for ink_x, ink_y in stroke[:2, 1:].T:\n",
        "        ctx.line_to(*apply_transform(ink_x, ink_y))\n",
        "      ctx.stroke()\n",
        "\n",
        "  return cairo_to_pil(surface)\n",
        "\n",
        "ink = read_inkml_file(os.path.join(MATHWRITING_ROOT_DIR, 'train', '0668dd347d600906.inkml'))"
      ],
      "metadata": {
        "id": "hvW6hlY4fHBt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "render_ink(ink)"
      ],
      "metadata": {
        "id": "Lo3obrPOw9za"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "render_ink(\n",
        "    ink,\n",
        "    stroke_width=5.0,\n",
        "    stroke_color=(.8, .1, .1),\n",
        "    background_color=(.96, .96, .86),\n",
        "  )"
      ],
      "metadata": {
        "id": "TyUfl8SExId1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Model Evaluation\n",
        "\n",
        "A few functions to compute the character error rate for model performance evaluation."
      ],
      "metadata": {
        "id": "ZNFYMexR0b-h"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Tokenizer\n",
        "\n",
        "Function to split a LaTeX expressions into tokens. These tokens can be used for model training or model evaluation."
      ],
      "metadata": {
        "id": "2YqK3TR_3A5N"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "_COMMAND_RE = re.compile(r'\\\\(mathbb{[a-zA-Z]}|begin{[a-z]+}|end{[a-z]+}|operatorname\\*|[a-zA-Z]+|.)')\n",
        "\n",
        "\n",
        "def tokenize_expression(s: str) -> list[str]:\n",
        "  r\"\"\"Transform a Latex math string into a list of tokens.\n",
        "\n",
        "  Tokens are strings that are meaningful in the context of Latex\n",
        "  e.g. '1', r'\\alpha', r'\\frac'.\n",
        "\n",
        "  Args:\n",
        "    s: unicode input string (ex: r\"\\frac{1}{2}\")\n",
        "\n",
        "  Returns:\n",
        "    tokens: list of tokens as unicode strings.\n",
        "  \"\"\"\n",
        "  tokens = []\n",
        "  while s:\n",
        "    if s[0] == '\\\\':\n",
        "      tokens.append(_COMMAND_RE.match(s).group(0))\n",
        "    else:\n",
        "      tokens.append(s[0])\n",
        "\n",
        "    s = s[len(tokens[-1]) :]\n",
        "\n",
        "  return tokens\n",
        "\n",
        "\n",
        "# Example\n",
        "print(tokenize_expression(r'\\frac{\\alpha}{2}\\begin{matrix}1&0\\\\0&1\\end{matrix}\\not\\in\\mathbb{R}'))"
      ],
      "metadata": {
        "id": "_9JiZqn93D_0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### CER Computation\n",
        "\n",
        "Computation of the Character Error Rate."
      ],
      "metadata": {
        "id": "Noulm6cy0tUq"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install jiwer\n",
        "import jiwer"
      ],
      "metadata": {
        "id": "h-K9tkXx1SJ1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def compute_cer(truth_and_output: list[tuple[str, str]]):\n",
        "  \"\"\"Computes CER given pairs of ground truth and model output.\"\"\"\n",
        "  class TokenizeTransform(jiwer.transforms.AbstractTransform):\n",
        "    def process_string(self, s: str):\n",
        "      return tokenize_expression(r'{}'.format(s))\n",
        "\n",
        "    def process_list(self, tokens: list[str]):\n",
        "      return [self.process_string(token) for token in tokens]\n",
        "\n",
        "  ground_truth, model_output = zip(*truth_and_output)\n",
        "\n",
        "  return jiwer.cer(truth=list(ground_truth),\n",
        "            hypothesis=list(model_output),\n",
        "            reference_transform=TokenizeTransform(),\n",
        "            hypothesis_transform=TokenizeTransform(),\n",
        "      )"
      ],
      "metadata": {
        "id": "HTFANiQJ2Ppv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Test data to run compute_cer().\n",
        "# The first element is the model prediction, the second the ground truth.\n",
        "examples = [\n",
        "    (r'\\sqrt{2}', r'\\sqrt{2}'),  # 0 mistakes, 4 tokens\n",
        "    (r'\\frac{1}{2}', r'\\frac{i}{2}'),  # 1 mistake, 7 tokens\n",
        "    (r'\\alpha^{2}', 'a^{2}'),  # 1 mistake, 5 tokens\n",
        "    ('abc', 'def'),  # 3 mistakes, 3 tokens\n",
        "]\n",
        "\n",
        "# 5 mistakes for 19 tokens: 26.3% error rate.\n",
        "print(f\"{compute_cer(examples)*100:.1f} %\")"
      ],
      "metadata": {
        "id": "uM3AXFLw0yMM"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}